{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import argparse\n",
    "import copy\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import re\n",
    "import sys\n",
    "from collections import defaultdict, deque\n",
    "from typing import List\n",
    "import re\n",
    "from typing import Any, List\n",
    "import ipydagred3\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import openai\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "from SCAN_TOT_utils import *\n",
    "\n",
    "#os.environ[\"http_proxy\"] = \"http://localhost:7890\"\n",
    "#os.environ[\"https_proxy\"] = \"http://localhost:7890\"\n",
    "# Configure the OpenAI client through the project helper (keyid presumably selects a stored key - confirm).\n",
    "setOpenAi(keyid=4)\n",
    "\n",
    "# Model used for every request in this notebook.\n",
    "# Alternatives that were tried before: \"gpt-4-turbo\", \"gpt-4o-mini\".\n",
    "GPT_MODEL = \"gpt-3.5-turbo\"\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_instruction(instruction):\n",
    "    \"\"\"Split one task record of the form ``IN: <command> OUT: <actions>``.\n",
    "\n",
    "    Args:\n",
    "        instruction: A single line of the task file.\n",
    "\n",
    "    Returns:\n",
    "        A ``(question, answer)`` tuple: the text before the first ``OUT:``\n",
    "        with every ``IN:`` marker removed, and the text after it, both stripped.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the line contains no ``OUT:`` marker.\n",
    "    \"\"\"\n",
    "    # partition() splits on the first marker only, so nothing after it is dropped.\n",
    "    head, marker, tail = instruction.partition('OUT:')\n",
    "    if not marker:\n",
    "        raise ValueError(f\"expected 'IN: ... OUT: ...', got {instruction!r}\")\n",
    "    question = head.replace('IN:', '').strip()\n",
    "    answer = tail.strip()\n",
    "    return question, answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load at most N tasks: `tasks` holds the commands, `solutions` the reference action lists.\n",
    "tasks = []\n",
    "solutions = []\n",
    "count = 0\n",
    "N = 100  # maximum number of task lines to load\n",
    "# NOTE(review): absolute, machine-specific path - adjust before running elsewhere.\n",
    "with open(\"/Users/natehu/Desktop/Tsinghua Research/TaDe/v1/scan_tasks.txt\", 'r', encoding= 'utf-8') as file:\n",
    "    for line in file:\n",
    "        line = line.strip()\n",
    "        if not line:\n",
    "            continue  # blank lines carry no task and would break the parser\n",
    "        question, actions = split_instruction(line)\n",
    "        tasks.append(question)\n",
    "        # Drop the \"I_\" prefix so tokens read e.g. TURN_LEFT instead of I_TURN_LEFT.\n",
    "        actions = [action.replace(\"I_\", \"\") for action in actions.split()]\n",
    "        solutions.append(actions)\n",
    "        count += 1  # previously never incremented, so the limit below never triggered\n",
    "        if count == N:\n",
    "            break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "jump left twice after turn around right twice\n",
      "['TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_RIGHT', 'TURN_LEFT', 'JUMP', 'TURN_LEFT', 'JUMP']\n"
     ]
    }
   ],
   "source": [
    "i = 8  # index of the task to work on\n",
    "\n",
    "# Pull the command and its reference action list for that index.\n",
    "question, golden_answer = tasks[i], solutions[i]\n",
    "\n",
    "print(question, golden_answer, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total tokens: 706\n",
      "['jump left', 'jump left twice', 'turn right', 'turn around right', 'turn around right twice']\n",
      "{1: 'jump left', 2: 'jump left twice', 3: 'turn right', 4: 'turn around right', 5: 'turn around right twice'}\n"
     ]
    }
   ],
   "source": [
    "\n",
    "# Ask the project helper to break the command into ordered sub-instructions.\n",
    "decompose_steps = decompose_sql(question)\n",
    "\n",
    "# `steps` is the list form; `steps_dict` holds the same entries keyed by a running number.\n",
    "steps, steps_dict = convert_steps_to_format(decompose_steps)\n",
    "num_steps = len(steps)\n",
    "\n",
    "print(steps, steps_dict, sep=\"\\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total tokens: 1380\n",
      "total tokens: 1416\n",
      "total tokens: 1383\n",
      "total tokens: 1419\n",
      "total tokens: 1383\n",
      "total tokens: 1419\n",
      "total tokens: 1468\n",
      "total tokens: 1504\n",
      "total tokens: 1470\n",
      "total tokens: 1506\n",
      "total tokens: 1468\n",
      "total tokens: 1504\n",
      "total tokens: 1471\n",
      "total tokens: 1507\n",
      "total tokens: 1471\n",
      "total tokens: 1507\n",
      "total tokens: 1471\n",
      "total tokens: 1507\n",
      "total tokens: 1508\n",
      "total tokens: 1544\n",
      "total tokens: 1508\n",
      "total tokens: 1544\n",
      "total tokens: 1508\n",
      "total tokens: 1544\n",
      "total tokens: 1515\n",
      "total tokens: 1551\n",
      "total tokens: 1510\n",
      "total tokens: 1546\n",
      "total tokens: 1512\n",
      "total tokens: 1548\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1555\n",
      "total tokens: 1591\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "total tokens: 1607\n",
      "total tokens: 1643\n",
      "2\n",
      "(24, ['\"jump left\" outputs \"TURN LEFT\" + \"JUMP\".', 'Sub-problem-Id: 1; Sub-problem: jump left twice; Answer: (\"TURN LEFT\" + \"JUMP\") * 2.', 'Sub-problem-Id: 2; Sub-problem: turn right; Answer: \"TURN RIGHT\".', 'Sub-problem-Id: 3; Sub-problem: turn around right; Answer: \"TURN RIGHT\" * 4.', 'Sub-problem-Id: 4; Sub-problem: turn around right twice; Answer: (\"TURN RIGHT\" * 4) * 2.'])\n",
      "(24, ['\"jump left\" outputs \"TURN LEFT\" + \"JUMP\".', 'Sub-problem-Id: 1; Sub-problem: jump left twice; Answer: (\"TURN LEFT\" + \"JUMP\") * 2.', 'Sub-problem-Id: 2; Sub-problem: turn right; Answer: \"TURN RIGHT\".', 'Sub-problem-Id: 3; Sub-problem: turn around right; Answer: \"TURN RIGHT\" * 4.', 'Sub-problem-Id: 4; Sub-problem: turn around right twice; Answer: (\"TURN RIGHT\" * 4) * 2.'])\n"
     ]
    }
   ],
   "source": [
    "# Tree-of-Thought (ToT) solving: beam search over the ordered sub-instructions.\n",
    "N = 3  # proposals sampled for every sub-instruction\n",
    "M = 2  # best-scoring reasoning paths kept after each step\n",
    "\n",
    "# Built once: the system prompt depends only on `question`, not on the sub-instruction.\n",
    "# NOTE(review): the prompt text is kept byte-for-byte so earlier runs stay comparable.\n",
    "# Known wording problems inside it: an unbalanced quote after the first sentence, the\n",
    "# typo \"concate+nates\", the \"turn opposite right\" example fused onto the previous answer\n",
    "# (so there are 14 examples, not 13), and \"walk around left\" used where \"walk left\" is meant.\n",
    "sys_q = f\"\"\"\n",
    "\n",
    "    There is a natural language instruction representing a sequence of actions. I need you to translate this sentence from natural language into a standardized meta-action sequence.\"\n",
    "    Here is the instruction:\\n{question}\n",
    "\n",
    "    I have broken this instruction down into some smaller instructions. I will assign you sub-instructions one by one, and provide the results of the previous sub-instructions as a reference for your reasoning.\n",
    "    Please organize your reasoning according to the combination and progression of actions.\n",
    "\n",
    "    For your reference, 13 examples for translation together with the corresponding explanations are as follows:\n",
    "\n",
    "    Q: \"turn left\"\n",
    "    A: \"turn left\" outputs \"TURN LEFT\".\n",
    "\n",
    "    Q: \"turn right\"\n",
    "    A: \"turn right\" outputs \"TURN RIGHT\".\n",
    "\n",
    "    Q: \"jump left\"\n",
    "    A: The output of “jump left” concatenates: the output of “turn left”, the output of “jump”. “turn left” outputs “TURN LEFT”. “jump” outputs “JUMP”. So concatenating the output of “turn left” and the output of “jump” leads to “TURN LEFT” + “JUMP”. So the output of “jump left” is “TURN LEFT” + “JUMP”.\n",
    "\n",
    "    Q: \"run right\"\n",
    "    A: The output of \"run right\" concatenates: the output of \"turn right\", the output of \"run\". \"turn right\" outputs \"TURN RIGHT\". \"run\" outputs \"RUN\". So concatenating the output of \"turn right\" and the output of \"run\" leads to \"TURN RIGHT\" + \"RUN\". So the output of \"run right\" is \"TURN RIGHT\" + \"RUN\".\n",
    "\n",
    "    Q: \"look twice\"\n",
    "    A: The output of \"look twice\" concatenates: the output of \"look\", the output of \"look\". \"look\" outputs \"LOOK\". So repeating the output of \"look\" two times leads to \"LOOK\" * 2. So the output of \"look twice\" is \"LOOK\" * 2.\n",
    "\n",
    "    Q: \"run and look twice\"\n",
    "    A: The output of \"run and look twice\" concate+nates: the output of \"run\", the output of \"look twice\". \"run\" outputs \"RUN\". \"look twice\" outputs \"LOOK\" * 2. So concatenating the output of \"run\" and the output of \"look twice\" leads to \"RUN\" + \"LOOK\" * 2. So the output of \"run and look twice\" is \"RUN\" + \"LOOK\" * 2.\n",
    "\n",
    "    Q: \"jump right thrice\"\n",
    "    A: The output of \"jump right thrice\" concatenates: the output of \"jump right\", the output of \"jump right\", the output of \"jump right\". \"jump right\" outputs \"TURN RIGHT\" + \"JUMP\". So repeating the output of \"jump right\" three times leads to (\"TURN RIGHT\" + \"JUMP\") * 3. So the output of \"jump right thrice\" is (\"TURN RIGHT\" + \"JUMP\") * 3.\n",
    "\n",
    "    Q: \"walk after run\"\n",
    "    A: The output of \"walk after run\" concatenates: the output of \"run\", the output of \"walk\". \"run\" outputs \"RUN\". \"walk\" outputs \"WALK\". So concatenating the output of \"run\" and the output of \"walk\" leads to \"RUN\" + \"WALK\". So the output of \"walk after run\" is \"RUN\" + \"WALK\".\n",
    "\n",
    "    Q: \"turn opposite left\"\n",
    "    A: The output of \"turn opposite left\" concatenates: the output of \"turn left\", the output of \"turn left\". \"turn left\" outputs \"TURN LEFT\". So repeating the output of \"turn left\" twice leads to \"TURN LEFT\" * 2. So the output of \"turn opposite left\" is \"TURN LEFT\" * 2.\n",
    "\n",
    "    Q: \"turn around left\"\n",
    "    A: The output of \"turn around left\" concatenates: the output of \"turn left\", the output of \"turn left\", the output of \"turn left\", the output of \"turn left\". \"turn left\" outputs \"TURN LEFT\". So repeating the output of \"turn left\" four times leads to \"TURN LEFT\" * 4. So the output of \"turn around left\" is \"TURN LEFT\" * 4. Q: \"turn opposite right\" A: The output of \"turn opposite right\" concatenates: the output of \"turn right\", the output of \"turn right\". \"turn right\" outputs \"TURN RIGHT\". So repeating the output of \"turn right\" twice leads to \"TURN RIGHT\" * 2. So the output of \"turn opposite right\" is \"TURN RIGHT\" * 2.\n",
    "\n",
    "    Q: \"turn around right\"\n",
    "    A: The output of \"turn around right\" concatenates: the output of \"turn right\", the output of \"turn right\", the output of \"turn right\", the output of \"turn right\". \"turn right\" outputs \"TURN RIGHT\". So repeating the output of \"turn right\" four times leads to \"TURN RIGHT\" * 4. So the output of \"turn around right\" is \"TURN RIGHT\" * 4.\n",
    "\n",
    "    Q: \"walk opposite left\"\n",
    "    A: The output of \"walk opposite left\" concatenates: the output of \"turn opposite left\", the output of \"walk\". \"turn opposite left\" outputs \"TURN LEFT\" * 2. \"walk\" outputs \"WALK\". So concatenating the output of \"turn opposite left\" and the output of \"walk\" leads to \"TURN LEFT\" * 2 + \"WALK\". So the output of \"walk opposite left\" is \"TURN LEFT\" * 2 + \"WALK\".\n",
    "\n",
    "    Q: \"walk around left\"\n",
    "    A: The output of \"walk around left\" concatenates: the output of \"walk left\", the output of \"walk left\", the output of \"walk left\", the output of \"walk left\". \"walk left\" outputs \"TURN LEFT\" + \"WALK\". So repeating the output of \"walk around left\" four times leads to (\"TURN LEFT\" + \"WALK\") * 4. So the output of \"walk around left\" is (\"TURN LEFT\" + \"WALK\") * 4.\n",
    "\n",
    "    Please pay attention to the use of parentheses.\n",
    "\"\"\"  \n",
    "\n",
    "# Follow-up question sent after each proposal to obtain a self-assessed score.\n",
    "CONFIDENCE_PROMPT = \"Please provide a confidence rating for the accuracy of this solution, on a scale from 1 to 5. Only output the number.\"\n",
    "\n",
    "\n",
    "def parse_confidence(reply, lowest=1, highest=5):\n",
    "    \"\"\"Extract an integer confidence score from a model reply.\n",
    "\n",
    "    Args:\n",
    "        reply: Text returned for the confidence question.\n",
    "        lowest: Smallest score of the scale; also used when nothing can be parsed.\n",
    "        highest: Largest score of the scale.\n",
    "\n",
    "    Returns:\n",
    "        The first integer found in ``reply``, clamped to ``[lowest, highest]``.\n",
    "        If the reply holds no digits, a warning is printed and ``lowest`` is returned.\n",
    "    \"\"\"\n",
    "    match = re.search(r\"\\d+\", str(reply))\n",
    "    if match is None:\n",
    "        # A chatty reply should not abort a run that already cost many requests.\n",
    "        print(f\"warning: no score found in {reply!r}; using {lowest}\")\n",
    "        return lowest\n",
    "    return max(lowest, min(highest, int(match.group())))\n",
    "\n",
    "\n",
    "# Each beam entry is (cumulative score, [answer to sub-instruction 0, 1, ...]).\n",
    "# Start from one empty path so the first step needs no special case.\n",
    "solution = [(0, [])] if num_steps else []\n",
    "\n",
    "for i in range(num_steps):\n",
    "    subtask = steps[i]\n",
    "    subask = f\"\"\"\\nThe sub-instruction to be converted now is:: {subtask}\n",
    "Based on the information above, please provide a concise and clear answer\"\"\"\n",
    "\n",
    "    temp_solution = []  # up to M*N extended paths for this step\n",
    "    for path_score, path_answers in solution:\n",
    "        if path_answers:\n",
    "            # Recap the answers this path has already committed to.\n",
    "            answersSoFar = f\"\"\"\\nSo far, the answers to the preceding sub-instructions are as follows: The format is Sub-problem-Id: xxx; Sub-problem: xxx; Answer: xxx.\"\"\"\n",
    "            for index, value in enumerate(path_answers):\n",
    "                answersSoFar += f\"\"\"\\nSub-problem-Id: {index}; Sub-problem: {steps[index]}; Answer: {value}.\"\"\"\n",
    "            query = answersSoFar + subask\n",
    "        else:\n",
    "            # First sub-instruction: there is nothing to recap yet.\n",
    "            query = subask\n",
    "        Q = [{'role':'system', 'content':sys_q},\n",
    "             {'role':'user', 'content':query},]\n",
    "        for n in range(N):  # sample N proposals for this path\n",
    "            result = askChatGPT(Q, model=GPT_MODEL, temperature=1)\n",
    "            eval_Q = Q + [{'role':'assistant', 'content':result},\n",
    "                          {'role':'user', 'content':CONFIDENCE_PROMPT}]\n",
    "            score = parse_confidence(askChatGPT(eval_Q, model=GPT_MODEL, temperature=1))\n",
    "            # Path score is the running sum of the per-step confidences.\n",
    "            temp_solution.append((path_score + score, path_answers + [result]))\n",
    "\n",
    "    # Prune: keep the M highest-scoring paths (the sort used to be ascending,\n",
    "    # which kept the lowest-scoring ones instead).\n",
    "    solution = sorted(temp_solution, key=lambda x: x[0], reverse=True)[:M]\n",
    "\n",
    "print(len(solution))\n",
    "printSeq(solution)\n",
    "# The final answer is obtained with one additional query in the cells below.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(24,\n",
       "  ['\"jump left\" outputs \"TURN LEFT\" + \"JUMP\".',\n",
       "   'Sub-problem-Id: 1; Sub-problem: jump left twice; Answer: (\"TURN LEFT\" + \"JUMP\") * 2.',\n",
       "   'Sub-problem-Id: 2; Sub-problem: turn right; Answer: \"TURN RIGHT\".',\n",
       "   'Sub-problem-Id: 3; Sub-problem: turn around right; Answer: \"TURN RIGHT\" * 4.',\n",
       "   'Sub-problem-Id: 4; Sub-problem: turn around right twice; Answer: (\"TURN RIGHT\" * 4) * 2.']),\n",
       " (24,\n",
       "  ['\"jump left\" outputs \"TURN LEFT\" + \"JUMP\".',\n",
       "   'Sub-problem-Id: 1; Sub-problem: jump left twice; Answer: (\"TURN LEFT\" + \"JUMP\") * 2.',\n",
       "   'Sub-problem-Id: 2; Sub-problem: turn right; Answer: \"TURN RIGHT\".',\n",
       "   'Sub-problem-Id: 3; Sub-problem: turn around right; Answer: \"TURN RIGHT\" * 4.',\n",
       "   'Sub-problem-Id: 4; Sub-problem: turn around right twice; Answer: (\"TURN RIGHT\" * 4) * 2.'])]"
      ]
     },
     "execution_count": 49,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the surviving reasoning paths as (cumulative score, list of sub-answers).\n",
    "solution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'role': 'user', 'content': 'There is a natural language command representing a sequence of actions:\\njump left twice after turn around right twice\\n\\nI have broken this action down into a series of smaller pieces and each sub-problem is solved.\\nThe sub-problems and their corresponding answers are as follows. \\nSub-problem-Id: 0; Sub-problem: jump left; Answer: \"jump left\" outputs \"TURN LEFT\" + \"JUMP\"..\\nSub-problem-Id: 1; Sub-problem: jump left twice; Answer: Sub-problem-Id: 1; Sub-problem: jump left twice; Answer: (\"TURN LEFT\" + \"JUMP\") * 2..\\nSub-problem-Id: 2; Sub-problem: turn right; Answer: Sub-problem-Id: 2; Sub-problem: turn right; Answer: \"TURN RIGHT\"..\\nSub-problem-Id: 3; Sub-problem: turn around right; Answer: Sub-problem-Id: 3; Sub-problem: turn around right; Answer: \"TURN RIGHT\" * 4..\\nSub-problem-Id: 4; Sub-problem: turn around right twice; Answer: Sub-problem-Id: 4; Sub-problem: turn around right twice; Answer: (\"TURN RIGHT\" * 4) * 2..\\n Given the results from the sub problems and their answer, so what is the final answer? (Try to concatenate some of the sub-problems  answers to get the final answer.)'}]\n"
     ]
    }
   ],
   "source": [
    "# Build the summarising prompt: the original command plus every sub-answer of one path.\n",
    "user_q = f\"\"\"There is a natural language command representing a sequence of actions:\\n{question}\n",
    "\n",
    "I have broken this action down into a series of smaller pieces and each sub-problem is solved.\n",
    "The sub-problems and their corresponding answers are as follows. \"\"\"\n",
    "\n",
    "# Only the highest-scoring path is summarised. It is selected explicitly so the\n",
    "# choice does not depend on how `solution` happens to be ordered.\n",
    "best_score, best_path = max(solution, key=lambda path: path[0])\n",
    "for index, value in enumerate(best_path):\n",
    "    user_q += f\"\"\"\\nSub-problem-Id: {index}; Sub-problem: {steps[index]}; Answer: {value}.\"\"\"\n",
    "\n",
    "user_q += \"\\n Given the results from the sub problems and their answer, so what is the final answer? (Try to concatenate some of the sub-problems  answers to get the final answer.)\"\n",
    "\n",
    "final_Q = [{'role':'user', 'content':user_q}]\n",
    "\n",
    "\n",
    "print(final_Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total tokens: 350\n"
     ]
    }
   ],
   "source": [
    "# Send the summarising prompt and keep the model's free-text final answer.\n",
    "finalResult = askChatGPT(final_Q, model=GPT_MODEL, temperature=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'jump left twice after turn around right twice'"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the command being translated.\n",
    "question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The final answer is: (\"TURN RIGHT\" * 4) * 2 + (\"TURN LEFT\" + \"JUMP\") * 2.'"
      ]
     },
     "execution_count": 63,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the model's final answer in its symbolic form.\n",
    "finalResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total tokens: 568\n"
     ]
    }
   ],
   "source": [
    "# Turn the symbolic answer into an action string with the project helper.\n",
    "# NOTE(review): the recorded output prints a token count, so this presumably issues another model request - confirm.\n",
    "actionSeq = sentenceRes2Actions(finalResult)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_RIGHT TURN_LEFT JUMP TURN_LEFT JUMP'"
      ]
     },
     "execution_count": 65,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the expanded action string.\n",
    "actionSeq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split on whitespace so the prediction has the same list-of-tokens shape as golden_answer.\n",
    "actionList = actionSeq.split()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_LEFT',\n",
       " 'JUMP',\n",
       " 'TURN_LEFT',\n",
       " 'JUMP']"
      ]
     },
     "execution_count": 67,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the predicted action tokens.\n",
    "actionList"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_RIGHT',\n",
       " 'TURN_LEFT',\n",
       " 'JUMP',\n",
       " 'TURN_LEFT',\n",
       " 'JUMP']"
      ]
     },
     "execution_count": 68,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the reference action tokens for comparison.\n",
    "golden_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 70,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Exact-match check: True only if both token lists are identical, element by element.\n",
    "actionList == golden_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this overwrites the prediction with an empty list - presumably a manual\n",
    "# check of the 'incorrect' branch in the next cell. Skip it when scoring real output.\n",
    "actionList = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "incorrect\n"
     ]
    }
   ],
   "source": [
    "# Report whether the predicted action list matches the reference exactly.\n",
    "print('correct' if actionList == golden_answer else 'incorrect')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ToT, second approach: sample several decompositions and solutions, keep the best-rated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ToT, second approach: sample N decompositions and keep the best-rated one,\n",
    "# then sample N complete solutions for it and keep the best-rated answer.\n",
    "# NOTE(review): the prompts in this cell talk about a \"math problem\" although `question`\n",
    "# is an action command here - they look copied from another task; confirm before use.\n",
    "\n",
    "N = 3  # samples drawn in each stage (decomposition, solving)\n",
    "\n",
    "\n",
    "# --- Decomposition stage ---\n",
    "# NOTE(review): `type` below is the Python builtin, and this call unpacks two values while\n",
    "# the earlier cell calls decompose_sql(question) and uses a single return value - confirm\n",
    "# the helper's signature in SCAN_TOT_utils.\n",
    "decompositions = []\n",
    "for i in range(N):\n",
    "    plan, Q = decompose_sql(question, type)\n",
    "    result, result_dict = convert_steps_to_format(plan)\n",
    "    \n",
    "    # Ask the model to rate the plan it just produced.\n",
    "    eval_Q = Q + [{'role':'assistant', 'content':f\"\"\"The plan is : {result}\"\"\"}]\n",
    "    eval_Q = eval_Q + [{'role':'user', 'content':\"Please provide a confidence rating for the accuracy of this step to step plan, on a scale from 1 to 10. Only output the number.\"}]\n",
    "    printSeq(eval_Q)\n",
    "    score = askChatGPT(eval_Q, model=GPT_MODEL, temperature=1)\n",
    "    score = int(score)\n",
    "    \n",
    "    decompositions.append((score, result_dict))\n",
    "\n",
    "# Keep the decomposition with the highest confidence.\n",
    "task_decompositions = sorted(decompositions, key=lambda x: x[0], reverse=True)\n",
    "best_task_decomposition = task_decompositions[0][1]\n",
    "\n",
    "print(\"******************************解题阶段******************************\")\n",
    "# --- Solving stage ---\n",
    "# Collected under its own name: reusing `solutions` would overwrite the reference\n",
    "# action lists loaded from the task file at the top of the notebook.\n",
    "candidate_solutions = []\n",
    "for i in range(N):\n",
    "    question_input = f\"\"\"\n",
    "There is a math_problem. I need you to solve it and give an answer.\n",
    "Here is the problem:\\n{question}\n",
    "\n",
    "I have broken this math problem down into a series of smaller problems. Here is the list of sub-problems: {best_task_decomposition}\n",
    "Please solve the problem and respond according to mathematical logic.\n",
    "    \n",
    "    \"\"\"\n",
    "    query = {\"role\": \"user\", \"content\": question_input}\n",
    "    Q_jie = [query]\n",
    "    sol = askChatGPT(Q_jie, model=GPT_MODEL, temperature=1)\n",
    "    # Ask the model to rate its own solution.\n",
    "    sol_eval_Q = Q_jie + [{'role':'assistant', 'content':sol}]\n",
    "    sol_eval_Q = sol_eval_Q + [{'role':'user', 'content':\"Please provide a confidence rating for the accuracy of this solution, on a scale from 1 to 10. Only output the number.\"}]\n",
    "    printSeq(sol_eval_Q)\n",
    "    sol_score = askChatGPT(sol_eval_Q, model=GPT_MODEL, temperature=1)\n",
    "    sol_score = int(sol_score)\n",
    "    \n",
    "    candidate_solutions.append((sol_score, sol))\n",
    "\n",
    "# Keep the solution with the highest confidence.\n",
    "sub_solutions = sorted(candidate_solutions, key=lambda x: x[0], reverse=True)\n",
    "best_solution = sub_solutions[0][1]\n",
    "\n",
    "print(\"best_solution: \", best_solution)\n",
    "\n",
    "print(\"******************************评估结果******************************\")\n",
    "\n",
    "# --- Evaluation stage: extract a bare final answer from the best solution ---\n",
    "last_query = [{\n",
    "    \"role\": \"user\",\n",
    "    \"content\": f\"\"\" \n",
    "    Given the problem and step-by-step plan and the solution as following: {best_solution}\n",
    "    \n",
    "    What is the final answer?\n",
    "    Please give the final answer without any additional explanation or clarification.\"\"\"\n",
    "}]\n",
    "\n",
    "finalResult = askChatGPT(last_query, model=GPT_MODEL, temperature=1)\n",
    "print(finalResult)\n",
    "\n",
    "\n",
    "# Let the language model judge whether the answer matches the reference.\n",
    "# The reference is `golden_answer` (set when the example was selected); the former\n",
    "# name `gold_answer` is not defined anywhere in this notebook.\n",
    "judgeAnswer = {'role':'user', 'content':f\"\"\"Here is a math problem with a standard answer and a student's solution. Please help me determine if the student's solution is correct. If the numerical value are same, then it is correct.\n",
    "\n",
    "Problem: {question}\n",
    "\n",
    "Standard answer: {golden_answer}\n",
    "\n",
    "Answer: {finalResult}\n",
    "\n",
    "If the student's answer is correct, just output True; otherwise, just output False.\n",
    "No explanation is required.\n",
    "\"\"\"}\n",
    "\n",
    "Q_judge = [judgeAnswer]\n",
    "ifcorrect = askChatGPT(Q_judge, model=GPT_MODEL, temperature=1)  # expected to be either True or False\n",
    "\n",
    "# `success` records that the pipeline ran to completion, not that the answer was right.\n",
    "if 'True' in ifcorrect:\n",
    "    print('correct')\n",
    "    success = True\n",
    "elif 'False' in ifcorrect:\n",
    "    print('error')\n",
    "    success = True\n",
    "else:\n",
    "    # Neither verdict was found, so the run did not finish cleanly.\n",
    "    print('unrecognised judge reply:', ifcorrect)\n",
    "    success = False"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
